from typing import Union, Optional, Dict, Callable, List
import torch
import torch.nn as nn

from ding.torch_utils import get_lstm
from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, MultiHead, RainbowHead, \
    QuantileHead, FQFHead, QRDQNHead, DistributionHead, BranchingHead
from ding.torch_utils.network.gtrxl import GTrXL


@MODEL_REGISTRY.register('dqn')
class DQN(nn.Module):
    """
    Overview:
        Network definition of Deep Q Network (DQN), the classic value-based RL algorithm for discrete action \
        spaces. The model chains two sub-modules: an ``encoder`` that turns the raw observation into a feature \
        embedding, and a ``head`` that maps this embedding to one Q value per action.
    Interfaces:
        ``__init__``, ``forward``.

    .. note::
        Built-in encoders are ``FCEncoder`` and ``ConvEncoder``; built-in heads are ``DiscreteHead`` and \
        ``DuelingHead``. Inherit from this class to plug in a customized encoder or head.
    """

    def __init__(
        self,
        obs_shape: Union[int, SequenceType],
        action_shape: Union[int, SequenceType],
        encoder_hidden_size_list: SequenceType = [128, 128, 64],
        dueling: bool = True,
        head_hidden_size: Optional[int] = None,
        head_layer_num: int = 1,
        activation: Optional[nn.Module] = nn.ReLU(),
        norm_type: Optional[str] = None,
        dropout: Optional[float] = None,
        init_bias: Optional[float] = None,
        noise: bool = False,
    ) -> None:
        """
        Overview:
            Build the DQN model (encoder + head) from the given arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84]. \
                An int or 1-element shape selects ``FCEncoder``, a 3-element shape selects ``ConvEncoder``, \
                anything else raises ``RuntimeError``.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3]. \
                A non-int shape (after ``squeeze``) wraps the head in ``MultiHead``.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
                the last element must match ``head_hidden_size``.
            - dueling (:obj:`bool`): ``True`` (default) selects ``DuelingHead``, ``False`` selects ``DiscreteHead``.
            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network, defaults to None, \
                then it will be set to the last element of ``encoder_hidden_size_list``.
            - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output.
            - activation (:obj:`Optional[nn.Module]`): The activation module handed to both the encoder and the \
                head. Defaults to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
            - dropout (:obj:`Optional[float]`): The dropout rate handed to ``FCEncoder`` and to the head. \
                It must be ``None`` for 3-dimensional observations because ``ConvEncoder`` does not support it.
            - init_bias (:obj:`Optional[float]`): If not ``None``, the bias of ``self.head.A[-1][0]`` is filled \
                with this value. It only takes effect for a single (non multi-head) ``DuelingHead`` and is \
                ignored otherwise.
            - noise (:obj:`bool`): Whether to use ``NoiseLinearLayer`` as ``layer_fn`` to boost exploration in \
                Q networks' MLP. Default to ``False``.
        """
        super().__init__()
        # Collapse single-element containers into a scalar, e.g. (4, ) -> 4.
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]

        # Encoder: vector observations go through an MLP, image-like (3-dim) observations through a CNN.
        is_vector_obs = isinstance(obs_shape, int) or len(obs_shape) == 1
        if not is_vector_obs and len(obs_shape) != 3:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own DQN".format(obs_shape)
            )
        if is_vector_obs:
            self.encoder = FCEncoder(
                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type, dropout=dropout
            )
        else:
            assert dropout is None, "dropout is not supported in ConvEncoder"
            self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)

        # Head: the keyword arguments are identical for the single-head and the multi-head construction.
        head_cls = DuelingHead if dueling else DiscreteHead
        common_head_kwargs = dict(activation=activation, norm_type=norm_type, dropout=dropout, noise=noise)
        if isinstance(action_shape, int):
            self.head = head_cls(head_hidden_size, action_shape, head_layer_num, **common_head_kwargs)
            if init_bias is not None and head_cls is DuelingHead:
                # Overwrite the last-layer bias of the advantage branch with ``init_bias``.
                self.head.A[-1][0].bias.data.fill_(init_bias)
        else:
            self.head = MultiHead(
                head_cls, head_hidden_size, action_shape, layer_num=head_layer_num, **common_head_kwargs
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            DQN forward pass: encode the observation, then let the head predict the Q values.
        Arguments:
            - x (:obj:`torch.Tensor`): The input observation tensor data.
        Returns:
            - outputs (:obj:`Dict`): The output of the head, including the q_value.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Discrete Q-value output of each possible action dimension.
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
            - logit (:obj:`torch.Tensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
        Examples:
            >>> model = DQN(32, 6)  # arguments: 'obs_shape' and 'action_shape'
            >>> inputs = torch.randn(4, 32)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 6])

        .. note::
            For consistency and compatibility, we name all the outputs of the network which are related to action \
            selections as ``logit``.
        """
        return self.head(self.encoder(x))


@MODEL_REGISTRY.register('bdq')
class BDQ(nn.Module):
    """
    Overview:
        The neural network structure and computation graph of Branching Dueling Q-Network (BDQ), see the paper \
        Action Branching Architectures for Deep Reinforcement Learning <https://arxiv.org/pdf/1711.08946>. \
        The BDQ is composed of an ``encoder``, which extracts the feature of the observation, and a \
        ``BranchingHead``, which computes the Q value of each action bin in each branch.
    Interfaces:
        ``__init__``, ``forward``.

    .. note::
        Current ``BDQ`` supports two types of encoder: ``FCEncoder`` and ``ConvEncoder``.
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            num_branches: int = 0,
            action_bins_per_branch: int = 2,
            layer_num: int = 3,
            a_layer_num: Optional[int] = None,
            v_layer_num: Optional[int] = None,
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            head_hidden_size: Optional[int] = None,
            norm_type: Optional[str] = None,
            activation: Optional[nn.Module] = nn.ReLU(),
    ) -> None:
        """
        Overview:
            Init the BDQ (encoder + head) Model according to input arguments. \
                referenced paper Action Branching Architectures for Deep Reinforcement Learning \
                <https://arxiv.org/pdf/1711.08946>
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84]. \
                An int or 1-element shape selects ``FCEncoder``, a 3-element shape selects ``ConvEncoder``, \
                anything else raises ``RuntimeError``.
            - num_branches (:obj:`int`): The number of branches, which is equivalent to the action dimension, \
                such as 6 in mujoco's halfcheetah environment.
            - action_bins_per_branch (:obj:`int`): The number of actions in each dimension.
            - layer_num (:obj:`int`): The number of layers used in the network to compute Advantage and Value output.
            - a_layer_num (:obj:`Optional[int]`): The number of layers used in the network to compute Advantage \
                output. Forwarded unchanged to ``BranchingHead``.
            - v_layer_num (:obj:`Optional[int]`): The number of layers used in the network to compute Value \
                output. Forwarded unchanged to ``BranchingHead``.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
                the last element must match ``head_hidden_size``.
            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network, defaults to None, \
                then it will be set to the last element of ``encoder_hidden_size_list``.
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details.
            - activation (:obj:`Optional[nn.Module]`): The activation module handed to both the encoder and the \
                head. Defaults to ``nn.ReLU()``.
        """
        super(BDQ, self).__init__()
        # For compatibility: 1, (1, ), [4, 32, 32]
        obs_shape, num_branches = squeeze(obs_shape), squeeze(num_branches)
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]

        # backbone
        # FC Encoder
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        # Conv Encoder
        elif len(obs_shape) == 3:
            self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        else:
            # The message names this class so users are pointed at the model they actually instantiated.
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own BDQ".format(obs_shape)
            )

        # Kept as attributes because ``forward`` needs ``num_branches`` to rescale the shared feature.
        self.num_branches = num_branches
        self.action_bins_per_branch = action_bins_per_branch

        # head
        self.head = BranchingHead(
            head_hidden_size,
            num_branches=self.num_branches,
            action_bins_per_branch=self.action_bins_per_branch,
            layer_num=layer_num,
            a_layer_num=a_layer_num,
            v_layer_num=v_layer_num,
            activation=activation,
            norm_type=norm_type
        )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            BDQ forward computation graph, input observation tensor to predict q_value.
        Arguments:
            - x (:obj:`torch.Tensor`): Observation inputs
        Returns:
            - outputs (:obj:`Dict`): BDQ forward outputs, such as q_value.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Discrete Q-value output of each action dimension.
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
            - logit (:obj:`torch.FloatTensor`): :math:`(B, D, K)`, where B is batch size, D is ``num_branches`` \
                and K is ``action_bins_per_branch``
        Examples:
            >>> model = BDQ(8, 5, 2)  # arguments: 'obs_shape', 'num_branches' and 'action_bins_per_branch'.
            >>> inputs = torch.randn(4, 8)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 5, 2])
        """
        # Scale the shared feature by 1 / (num_branches + 1) before it is fed to the branches;
        # this corresponds to the "Gradient Rescaling" in the paper.
        x = self.encoder(x) / (self.num_branches + 1)
        x = self.head(x)
        return x


@MODEL_REGISTRY.register('c51dqn')
class C51DQN(nn.Module):
    """
    Overview:
        The neural network structure and computation graph of C51DQN, which combines distributional RL and DQN. \
        You can refer to https://arxiv.org/pdf/1707.06887.pdf for more details. The C51DQN is composed of \
        ``encoder`` and ``head``. ``encoder`` is used to extract the feature of observation, and ``head`` is \
        used to compute the distribution of Q-value.
    Interfaces:
        ``__init__``, ``forward``

    .. note::
        Current C51DQN supports two types of encoder: ``FCEncoder`` and ``ConvEncoder``.
    """

    def __init__(
        self,
        obs_shape: Union[int, SequenceType],
        action_shape: Union[int, SequenceType],
        encoder_hidden_size_list: SequenceType = [128, 128, 64],
        head_hidden_size: Optional[int] = None,
        head_layer_num: int = 1,
        activation: Optional[nn.Module] = nn.ReLU(),
        norm_type: Optional[str] = None,
        v_min: Optional[float] = -10,
        v_max: Optional[float] = 10,
        n_atom: Optional[int] = 51,
    ) -> None:
        """
        Overview:
            initialize the C51 Model according to corresponding input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84]. \
                An int or 1-element shape selects ``FCEncoder``, a 3-element shape selects ``ConvEncoder``, \
                anything else raises ``RuntimeError``.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3]. \
                A non-int shape (after ``squeeze``) wraps ``DistributionHead`` in ``MultiHead``.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
                the last element must match ``head_hidden_size``.
            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network, defaults to None, \
                then it will be set to the last element of ``encoder_hidden_size_list``.
            - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output.
            - activation (:obj:`Optional[nn.Module]`): The activation module handed to both the encoder and the \
                head. Defaults to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
            - v_min (:obj:`Optional[float]`): The minimum value of the support of the distribution, which is related \
                to the value (discounted sum of reward) scale of the specific environment. Defaults to -10.
            - v_max (:obj:`Optional[float]`): The maximum value of the support of the distribution, which is related \
                to the value (discounted sum of reward) scale of the specific environment. Defaults to 10.
            - n_atom (:obj:`Optional[int]`): The number of atoms in the prediction distribution, 51 is the default \
                value in the paper, you can also try other values such as 301.
        """
        super(C51DQN, self).__init__()
        # For compatibility: 1, (1, ), [4, 32, 32]
        obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]
        # FC Encoder
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        # Conv Encoder
        elif len(obs_shape) == 3:
            self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own C51DQN".format(obs_shape)
            )
        # Head Type: one ``DistributionHead`` for an int action shape, otherwise one per action dimension
        # through ``MultiHead``. Both variants receive the same distribution settings.
        multi_head = not isinstance(action_shape, int)
        if multi_head:
            self.head = MultiHead(
                DistributionHead,
                head_hidden_size,
                action_shape,
                layer_num=head_layer_num,
                activation=activation,
                norm_type=norm_type,
                n_atom=n_atom,
                v_min=v_min,
                v_max=v_max,
            )
        else:
            self.head = DistributionHead(
                head_hidden_size,
                action_shape,
                head_layer_num,
                activation=activation,
                norm_type=norm_type,
                n_atom=n_atom,
                v_min=v_min,
                v_max=v_max,
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            C51DQN forward computation graph, input observation tensor to predict q_value and its distribution.
        Arguments:
            - x (:obj:`torch.Tensor`): The input observation tensor data.
        Returns:
            - outputs (:obj:`Dict`): The output of C51DQN's forward, including q_value, and distribution.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Discrete Q-value output of each possible action dimension.
            - distribution (:obj:`torch.Tensor`): Q-Value discretized distribution, i.e., probability of each \
                uniformly spaced atom Q-value, such as dividing [-10, 10] into 51 uniform spaces.
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
            - logit (:obj:`torch.Tensor`): :math:`(B, M)`, where M is ``action_shape``.
            - distribution(:obj:`torch.Tensor`): :math:`(B, M, P)`, where P is ``n_atom``.
        Examples:
            >>> model = C51DQN(128, 64)  # arguments: 'obs_shape' and 'action_shape'
            >>> inputs = torch.randn(4, 128)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict)
            >>> # action_shape is 64
            >>> assert outputs['logit'].shape == torch.Size([4, 64])
            >>> # default n_atom: int = 51
            >>> assert outputs['distribution'].shape == torch.Size([4, 64, 51])

        .. note::
            For consistency and compatibility, we name all the outputs of the network which are related to action \
            selections as ``logit``.

        .. note::
            For convenience, we recommend that the number of atoms should be odd, so that (assuming the atoms are \
            evenly spaced over ``[v_min, v_max]`` with both ends included) the middle atom lies exactly at the \
            midpoint of the support.
        """
        x = self.encoder(x)
        x = self.head(x)
        return x


@MODEL_REGISTRY.register('qrdqn')
class QRDQN(nn.Module):
    """
    Overview:
        Network definition of QRDQN, a distributional variant of DQN. See the paper \
        Distributional Reinforcement Learning with Quantile Regression \
        https://arxiv.org/pdf/1710.10044.pdf for more details. The model is an ``encoder`` followed by a \
        ``QRDQNHead`` (wrapped in ``MultiHead`` for non-int action shapes).
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            num_quantiles: int = 32,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Build the QRDQN model (encoder + head) from the given arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape. An int or 1-element shape \
                selects ``FCEncoder``, a 3-element shape selects ``ConvEncoder``, anything else raises \
                ``RuntimeError``.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``
            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to ``Head``. When ``None``, \
                the last element of ``encoder_hidden_size_list`` is used.
            - head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output
            - num_quantiles (:obj:`int`): Number of quantiles in the prediction distribution.
            - activation (:obj:`Optional[nn.Module]`): The activation module handed to both the encoder and the \
                head. Defaults to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to use, see \
                ``ding.torch_utils.fc_block`` for more details.
        """
        super().__init__()
        # For compatibility: 1, (1, ), [4, 32, 32]
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]

        # Encoder: reject unsupported ranks first, then choose between the MLP and the CNN encoder.
        is_vector_obs = isinstance(obs_shape, int) or len(obs_shape) == 1
        if not is_vector_obs and len(obs_shape) != 3:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own QRDQN".format(obs_shape)
            )
        encoder_cls = FCEncoder if is_vector_obs else ConvEncoder
        self.encoder = encoder_cls(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)

        # Head: single ``QRDQNHead`` for an int action shape, ``MultiHead`` of them otherwise.
        common_head_kwargs = dict(num_quantiles=num_quantiles, activation=activation, norm_type=norm_type)
        if isinstance(action_shape, int):
            self.head = QRDQNHead(head_hidden_size, action_shape, head_layer_num, **common_head_kwargs)
        else:
            self.head = MultiHead(
                QRDQNHead, head_hidden_size, action_shape, layer_num=head_layer_num, **common_head_kwargs
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            QRDQN forward pass: encode the observation, then let the head predict the quantile outputs.
        Arguments:
            - x (:obj:`torch.Tensor`): The input observation tensor data.
        Returns:
            - outputs (:obj:`Dict`): The prediction dictionary returned by the head.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Q-value output of each possible action dimension.
            - q (:obj:`torch.Tensor`): Per-quantile Q value tensor.
            - tau (:obj:`torch.Tensor`): tau tensor.
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is action_shape.
            - q (:obj:`torch.Tensor`): :math:`(B, M, num\\_quantiles)`
            - tau (:obj:`torch.Tensor`):  :math:`(B, num\\_quantiles, 1)`
        Examples:
            >>> model = QRDQN(64, 64)  # arguments: 'obs_shape' and 'action_shape'
            >>> inputs = torch.randn(4, 64)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict)
            >>> assert outputs['logit'].shape == torch.Size([4, 64])
            >>> # default num_quantiles : int = 32
            >>> assert outputs['q'].shape == torch.Size([4, 64, 32])
            >>> assert outputs['tau'].shape == torch.Size([4, 32, 1])
        """
        return self.head(self.encoder(x))


@MODEL_REGISTRY.register('iqn')
class IQN(nn.Module):
    """
    Overview:
        Network definition of IQN, a distributional variant of DQN. See the paper \
        Implicit Quantile Networks for Distributional Reinforcement Learning \
        https://arxiv.org/pdf/1806.06923.pdf for more details. The model is an ``encoder`` followed by a \
        ``QuantileHead`` (wrapped in ``MultiHead`` for non-int action shapes).
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            num_quantiles: int = 32,
            quantile_embedding_size: int = 128,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None
    ) -> None:
        """
        Overview:
            Build the IQN model (encoder + head) from the given arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape. An int or 1-element shape \
                selects ``FCEncoder``, a 3-element shape selects ``ConvEncoder``, anything else raises \
                ``RuntimeError``.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``
            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to ``Head``. When ``None``, \
                the last element of ``encoder_hidden_size_list`` is used.
            - head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output
            - num_quantiles (:obj:`int`): Number of quantiles in the prediction distribution.
            - quantile_embedding_size (:obj:`int`): Forwarded to the head as ``quantile_embedding_size``.
            - activation (:obj:`Optional[nn.Module]`): The activation module handed to both the encoder and the \
                head. Defaults to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to use, see \
                ``ding.torch_utils.fc_block`` for more details.
        """
        super().__init__()
        # For compatibility: 1, (1, ), [4, 32, 32]
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]

        # Encoder: reject unsupported ranks first, then choose between the MLP and the CNN encoder.
        is_vector_obs = isinstance(obs_shape, int) or len(obs_shape) == 1
        if not is_vector_obs and len(obs_shape) != 3:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own IQN".format(obs_shape)
            )
        encoder_cls = FCEncoder if is_vector_obs else ConvEncoder
        self.encoder = encoder_cls(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)

        # Head: single ``QuantileHead`` for an int action shape, ``MultiHead`` of them otherwise.
        common_head_kwargs = dict(
            activation=activation,
            norm_type=norm_type,
            num_quantiles=num_quantiles,
            quantile_embedding_size=quantile_embedding_size,
        )
        if isinstance(action_shape, int):
            self.head = QuantileHead(head_hidden_size, action_shape, head_layer_num, **common_head_kwargs)
        else:
            self.head = MultiHead(
                QuantileHead, head_hidden_size, action_shape, layer_num=head_layer_num, **common_head_kwargs
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            IQN forward pass: encode the observation, then let the head predict the quantile outputs.
        Arguments:
            - x (:obj:`torch.Tensor`): The input observation tensor data.
        Returns:
            - outputs (:obj:`Dict`): The prediction dictionary returned by the head.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Q-value output of each possible action dimension.
            - q (:obj:`torch.Tensor`): Per-quantile Q value tensor.
            - quantiles (:obj:`torch.Tensor`): quantiles tensor.
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is action_shape
            - q (:obj:`torch.Tensor`): :math:`(num\\_quantiles, B, M)`
            - quantiles (:obj:`torch.Tensor`): :math:`(P, 1)`. NOTE(review): in the example below P = 128 equals \
                both ``quantile_embedding_size`` and ``B * num_quantiles``; confirm against ``QuantileHead`` \
                which one it is.
        Examples:
            >>> model = IQN(64, 64) # arguments: 'obs_shape' and 'action_shape'
            >>> inputs = torch.randn(4, 64)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict)
            >>> assert outputs['logit'].shape == torch.Size([4, 64])
            >>> # default num_quantiles: int = 32
            >>> assert outputs['q'].shape == torch.Size([32, 4, 64])
            >>> assert outputs['quantiles'].shape == torch.Size([128, 1])
        """
        return self.head(self.encoder(x))


@MODEL_REGISTRY.register('fqf')
class FQF(nn.Module):
    """
    Overview:
        Network definition of FQF, a distributional variant of DQN. See the paper \
        Fully Parameterized Quantile Function for Distributional Reinforcement Learning \
        https://arxiv.org/pdf/1911.02140.pdf for more details. The model is an ``encoder`` followed by a \
        ``FQFHead`` (wrapped in ``MultiHead`` for non-int action shapes).
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            num_quantiles: int = 32,
            quantile_embedding_size: int = 128,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None
    ) -> None:
        """
        Overview:
            Build the FQF model (encoder + head) from the given arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape. An int or 1-element shape \
                selects ``FCEncoder``, a 3-element shape selects ``ConvEncoder``, anything else raises \
                ``RuntimeError``.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``
            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to ``Head``. When ``None``, \
                the last element of ``encoder_hidden_size_list`` is used.
            - head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output
            - num_quantiles (:obj:`int`): Number of quantiles in the prediction distribution.
            - quantile_embedding_size (:obj:`int`): Forwarded to the head as ``quantile_embedding_size``.
            - activation (:obj:`Optional[nn.Module]`): The activation module handed to both the encoder and the \
                head. Defaults to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to use, see \
                ``ding.torch_utils.fc_block`` for more details.
        """
        super().__init__()
        # For compatibility: 1, (1, ), [4, 32, 32]
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]

        # Encoder: reject unsupported ranks first, then choose between the MLP and the CNN encoder.
        is_vector_obs = isinstance(obs_shape, int) or len(obs_shape) == 1
        if not is_vector_obs and len(obs_shape) != 3:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own FQF".format(obs_shape)
            )
        encoder_cls = FCEncoder if is_vector_obs else ConvEncoder
        self.encoder = encoder_cls(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)

        # Head: single ``FQFHead`` for an int action shape, ``MultiHead`` of them otherwise.
        common_head_kwargs = dict(
            activation=activation,
            norm_type=norm_type,
            num_quantiles=num_quantiles,
            quantile_embedding_size=quantile_embedding_size,
        )
        if isinstance(action_shape, int):
            self.head = FQFHead(head_hidden_size, action_shape, head_layer_num, **common_head_kwargs)
        else:
            self.head = MultiHead(
                FQFHead, head_hidden_size, action_shape, layer_num=head_layer_num, **common_head_kwargs
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            FQF forward pass: encode the observation, then let the head predict the quantile outputs.
        Arguments:
            - x (:obj:`torch.Tensor`): The input observation tensor data.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``logit`` (:obj:`torch.Tensor`), \
                    ``q`` (:obj:`torch.Tensor`), ``quantiles`` (:obj:`torch.Tensor`), \
                    ``quantiles_hats`` (:obj:`torch.Tensor`), \
                    ``q_tau_i`` (:obj:`torch.Tensor`), ``entropies`` (:obj:`torch.Tensor`).
        Shapes:
            - x: :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
            - logit: :math:`(B, M)`, where M is action_shape.
            - q: :math:`(B, num_quantiles, M)`.
            - quantiles: :math:`(B, num_quantiles + 1)`.
            - quantiles_hats: :math:`(B, num_quantiles)`.
            - q_tau_i: :math:`(B, num_quantiles - 1, M)`.
            - entropies: :math:`(B, 1)`.
        Examples:
            >>> model = FQF(64, 64) # arguments: 'obs_shape' and 'action_shape'
            >>> inputs = torch.randn(4, 64)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict)
            >>> assert outputs['logit'].shape == torch.Size([4, 64])
            >>> # default num_quantiles: int = 32
            >>> assert outputs['q'].shape == torch.Size([4, 32, 64])
            >>> assert outputs['quantiles'].shape == torch.Size([4, 33])
            >>> assert outputs['quantiles_hats'].shape == torch.Size([4, 32])
            >>> assert outputs['q_tau_i'].shape == torch.Size([4, 31, 64])
            >>> assert outputs['entropies'].shape == torch.Size([4, 1])
        """
        return self.head(self.encoder(x))


@MODEL_REGISTRY.register('rainbowdqn')
class RainbowDQN(nn.Module):
    """
    Overview:
        The neural network structure and computation graph of RainbowDQN, which combines distributional RL and DQN. \
        The model is an observation ``encoder`` followed by a ``RainbowHead`` (wrapped in ``MultiHead`` when the \
        action shape is not a single int). See the paper Rainbow: Combining Improvements in Deep Reinforcement \
        Learning https://arxiv.org/pdf/1710.02298.pdf for more details.
    Interfaces:
        ``__init__``, ``forward``

    .. note::
        RainbowDQN contains dueling architecture by default.
    """

    def __init__(
        self,
        obs_shape: Union[int, SequenceType],
        action_shape: Union[int, SequenceType],
        encoder_hidden_size_list: SequenceType = [128, 128, 64],
        head_hidden_size: Optional[int] = None,
        head_layer_num: int = 1,
        activation: Optional[nn.Module] = nn.ReLU(),
        norm_type: Optional[str] = None,
        v_min: Optional[float] = -10,
        v_max: Optional[float] = 10,
        n_atom: Optional[int] = 51,
    ) -> None:
        """
        Overview:
            Build the encoder and the Rainbow head according to the given arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape. An int or a length-1 shape \
                selects ``FCEncoder``; a length-3 shape selects ``ConvEncoder``.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape. An int builds a single \
                ``RainbowHead``; anything else builds a ``MultiHead`` of ``RainbowHead``.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``.
            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to ``Head``. When ``None``, the \
                last element of ``encoder_hidden_size_list`` is used.
            - head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output.
            - activation (:obj:`Optional[nn.Module]`): The activation passed to both encoder and head.
            - norm_type (:obj:`Optional[str]`): The type of normalization to use, see \
                ``ding.torch_utils.fc_block`` for more details.
            - v_min (:obj:`Optional[float]`): Forwarded to ``RainbowHead`` as ``v_min`` (presumably the lower \
                bound of the value support; see ``RainbowHead``).
            - v_max (:obj:`Optional[float]`): Forwarded to ``RainbowHead`` as ``v_max`` (presumably the upper \
                bound of the value support; see ``RainbowHead``).
            - n_atom (:obj:`Optional[int]`): Number of atoms in the prediction distribution.
        Raises:
            - RuntimeError: If ``obs_shape`` is neither an int, a length-1 shape nor a length-3 shape.
        """
        super().__init__()
        # For compatibility with shape specs such as 1, (1, ), [4, 32, 32]
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]

        # Pick the encoder class from the observation shape, then build it once.
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            encoder_cls = FCEncoder
        elif len(obs_shape) == 3:
            encoder_cls = ConvEncoder
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own RainbowDQN".
                format(obs_shape)
            )
        self.encoder = encoder_cls(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)

        # Keyword arguments shared by the single-head and multi-head construction.
        head_kwargs = dict(activation=activation, norm_type=norm_type, n_atom=n_atom, v_min=v_min, v_max=v_max)
        if isinstance(action_shape, int):
            self.head = RainbowHead(head_hidden_size, action_shape, head_layer_num, **head_kwargs)
        else:
            self.head = MultiHead(
                RainbowHead, head_hidden_size, action_shape, layer_num=head_layer_num, **head_kwargs
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            Run the Rainbow forward computation: the input goes through ``self.encoder`` and the resulting \
            embedding goes through ``self.head``, whose output is returned unchanged.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor fed to the encoder, with shape ``(B, N)``.
        Returns:
            - outputs (:obj:`Dict`): The prediction dictionary produced by the head.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Logit tensor of size ``(B, M)``.
            - distribution (:obj:`torch.Tensor`): Distribution tensor of size ``(B, M, n_atom)``.
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is the input size of the encoder.
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is action_shape.
            - distribution(:obj:`torch.FloatTensor`): :math:`(B, M, P)`, where P is n_atom.
        Examples:
            >>> model = RainbowDQN(64, 64) # arguments: 'obs_shape' and 'action_shape'
            >>> inputs = torch.randn(4, 64)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict)
            >>> assert outputs['logit'].shape == torch.Size([4, 64])
            >>> # default n_atom: int =51
            >>> assert outputs['distribution'].shape == torch.Size([4, 64, 51])
        """
        embedding = self.encoder(x)
        return self.head(embedding)


def parallel_wrapper(forward_fn: Callable) -> Callable:
    """
    Overview:
        Process timestep T and batch_size B at the same time, in other words, treat different timestep data as
        different trajectories in a batch. The wrapped function receives the input with its first two dimensions
        merged into one, and every tensor it returns gets that leading dimension split back into ``(T, B)``.
    Arguments:
        - forward_fn (:obj:`Callable`): Normal ``nn.Module`` 's forward function. It may return a tensor, or an
            arbitrarily nested ``list`` / ``tuple`` / ``dict`` of tensors.
    Returns:
        - wrapper (:obj:`Callable`): Wrapped function taking a tensor of shape ``(T, B, ...)``.
    """

    def wrapper(x: torch.Tensor) -> Union[torch.Tensor, list, tuple, dict]:
        T, B = x.shape[:2]

        def restore(d):
            # Recurse through containers; the container type (list / tuple / dict) is preserved.
            if isinstance(d, dict):
                return {k: restore(v) for k, v in d.items()}
            if isinstance(d, list):
                return [restore(t) for t in d]
            if isinstance(d, tuple):
                # Tuples previously fell through to ``d.reshape`` and raised AttributeError.
                return tuple(restore(t) for t in d)
            # Leaf tensor: split the merged leading dimension (T * B) back into (T, B).
            return d.reshape(T, B, *d.shape[1:])

        # NOTE(rjy): the initial input shape will be (T, B, N),
        #            means encoder or head should process B trajectorys, each trajectory has T timestep,
        #            but T and B dimension can be both treated as batch_size in encoder and head,
        #            i.e., independent and parallel processing,
        #            so here we need such fn to reshape for encoder or head
        merged = x.reshape(T * B, *x.shape[2:])
        return restore(forward_fn(merged))

    return wrapper


@MODEL_REGISTRY.register('drqn')
class DRQN(nn.Module):
    """
    Overview:
        The DRQN (Deep Recurrent Q-Network) is a neural network model combining DQN with RNN to handle sequential
        data and partially observable environments. It is built from three components that are applied in order:
        ``encoder``, ``rnn`` and ``head``.
        - **Encoder**: Extracts features from the observation input.
        - **RNN**: Processes the encoded features step by step while carrying a recurrent state.
        - **Head**: Computes Q-values for each action dimension.

    Interfaces:
        ``__init__``, ``forward``.

    .. note::
        The current implementation supports:
        - Two encoder types: ``FCEncoder`` and ``ConvEncoder``.
        - Two head types: ``DiscreteHead`` and ``DuelingHead``.
        - Three RNN types: ``normal (LSTM with LayerNorm)``, ``pytorch`` (PyTorch's native LSTM), and ``gru``.
        You can extend the model by customizing your own encoder, RNN, or head by inheriting this class.
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            dueling: bool = True,
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            lstm_type: Optional[str] = 'normal',
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            res_link: bool = False
    ) -> None:
        """
        Overview:
            Build the encoder, the RNN and the head of the DRQN model from the given parameters.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Shape of the observation space, e.g., 8 or [4, 84, 84]. \
                An int or a length-1 shape selects ``FCEncoder``; a length-3 shape selects ``ConvEncoder``.
            - action_shape (:obj:`Union[int, SequenceType]`): Shape of the action space, e.g., 6 or [2, 3, 3]. \
                An int builds a single head; anything else builds a ``MultiHead``.
            - encoder_hidden_size_list (:obj:`SequenceType`): List of hidden sizes for the encoder. The last element \
                must match ``head_hidden_size``.
            - dueling (:obj:`Optional[bool]`): Use ``DuelingHead`` if True, otherwise use ``DiscreteHead``.
            - head_hidden_size (:obj:`Optional[int]`): Hidden size for the head network; it is also used as both \
                the input size and the hidden size of the RNN. Defaults to the last element of \
                ``encoder_hidden_size_list`` if None.
            - head_layer_num (:obj:`int`): Number of layers in the head network to compute Q-value outputs.
            - lstm_type (:obj:`Optional[str]`): Type of RNN module, passed to ``get_lstm``. Supported types are \
                ``normal``, ``pytorch``, and ``gru``.
            - activation (:obj:`Optional[nn.Module]`): Activation function used in the encoder and the head. \
                Defaults to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): Normalization type for the networks. Supported types are: \
                ['BN', 'IN', 'SyncBN', 'LN']. See ``ding.torch_utils.fc_block`` for more details.
            - res_link (:obj:`bool`): If True, the encoder output is added to the RNN output before the head \
                (a residual connection around the RNN). Defaults to False.
        Raises:
            - RuntimeError: If ``obs_shape`` is neither an int, a length-1 shape nor a length-3 shape.
        """
        super().__init__()
        # Normalize obs_shape/action_shape so scalar, tuple, or multi-dimensional specs are handled uniformly.
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]

        # Encoder: choose the class from the observation shape, then instantiate it once.
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            encoder_cls = FCEncoder
        elif len(obs_shape) == 3:
            encoder_cls = ConvEncoder
        else:
            raise RuntimeError(
                f"Unsupported obs_shape for pre-defined encoder: {obs_shape}. Please customize your own DRQN."
            )
        self.encoder = encoder_cls(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)

        # RNN: built by ``get_lstm`` from lstm_type; input and hidden sizes both equal head_hidden_size.
        self.rnn = get_lstm(lstm_type, input_size=head_hidden_size, hidden_size=head_hidden_size)
        self.res_link = res_link

        # Head: Dueling or Discrete, wrapped in MultiHead when action_shape is not a single int.
        head_cls = DuelingHead if dueling else DiscreteHead
        if isinstance(action_shape, int):
            self.head = head_cls(
                head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
            )
        else:
            self.head = MultiHead(
                head_cls,
                head_hidden_size,
                action_shape,
                layer_num=head_layer_num,
                activation=activation,
                norm_type=norm_type
            )

    def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: Optional[list] = None) -> Dict:
        """
        Overview:
            Forward pass of the DRQN model (encoder -> RNN -> head). Takes the observation and the previous RNN \
            state as inputs and predicts Q-values together with the updated RNN state.
        Arguments:
            - inputs (:obj:`Dict`): Input data dictionary containing observation and previous RNN state.
            - inference (:obj:`bool`): If True, unrolls one timestep (used during evaluation). If False, unrolls \
                the entire sequence (used during training).
            - saved_state_timesteps (:obj:`Optional[list]`): Only used when ``inference`` is False. The RNN state \
                obtained after processing step ``t`` (0-based) is saved when ``t + 1`` is in this list.
        ArgumentsKeys:
            - obs (:obj:`torch.Tensor`): Raw observation tensor.
            - prev_state (:obj:`list`): Previous RNN state tensor, structure depends on ``lstm_type``.
        Returns:
            - outputs (:obj:`Dict`): The output of DRQN's forward, including logit (q_value) and next state.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Discrete Q-value output for each action dimension.
            - next_state (:obj:`list`): Next RNN state tensor.
            - hidden_state (:obj:`torch.Tensor`): Only when ``inference`` is False. The ``h`` states of all \
                timesteps, concatenated along dim 0.
            - saved_state (:obj:`list`): Only when ``inference`` is False and ``saved_state_timesteps`` is given. \
                The RNN states saved at the requested timesteps.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N)` when ``inference`` is True; :math:`(T, B, N)` or \
                :math:`(T, B, C, H, W)` when ``inference`` is False, where B is batch size, T is the number of \
                timesteps and N is ``obs_shape``.
            - logit (:obj:`torch.Tensor`): :math:`(B, M)` when ``inference`` is True (with a leading T dimension \
                otherwise), where M is ``action_shape``.
        Examples:
            >>> # Initialize input keys
            >>> prev_state = [[torch.randn(1, 1, 64) for __ in range(2)] for _ in range(4)] # B=4
            >>> obs = torch.randn(4,64)
            >>> model = DRQN(64, 64) # arguments: 'obs_shape' and 'action_shape'
            >>> outputs = model({'obs': obs, 'prev_state': prev_state}, inference=True)
            >>> # Validate output keys and shapes
            >>> assert isinstance(outputs, dict)
            >>> assert outputs['logit'].shape == (4, 64)
            >>> assert len(outputs['next_state']) == 4
            >>> assert all([len(t) == 2 for t in outputs['next_state']])
            >>> assert all([t[0].shape == (1, 1, 64) for t in outputs['next_state']])
        """
        obs, rnn_state = inputs['obs'], inputs['prev_state']

        # In most situations, set inference=True when evaluating and inference=False when training.
        if inference:
            # Single-step mode: the RNN is called once with seq_len=1.
            feature = self.encoder(obs)
            # rnn_state: List[Tuple[torch.Tensor]]; initially, it is a list of None.
            # unsqueeze(0) adds the seq_len dim expected by the RNN, squeeze(0) removes it again for the head.
            rnn_out, next_state = self.rnn(feature.unsqueeze(0), rnn_state)
            core = rnn_out.squeeze(0)
            if self.res_link:
                core = core + feature
            result = self.head(core)
            result['next_state'] = next_state
            return result

        # Sequence mode: unroll the RNN over the whole time dimension.
        # Why saved states are needed, taking r2d2 as an example:
        # 1) data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep]
        # 2) data['main_obs'] = data['obs'][bs:-self._nstep]
        # 3) data['target_obs'] = data['obs'][bs + self._nstep:]
        assert len(obs.shape) in [3, 5], f"Expected shape (T, B, N) or (T, B, C, H, W), got {obs.shape}"
        feature = parallel_wrapper(self.encoder)(obs)  # (T, B, N)

        keep_states = saved_state_timesteps is not None
        step_outputs = []  # RNN output of every timestep
        # TODO(nyz) how to deal with hidden_size key-value
        step_hiddens = []  # only the h part of the state of every timestep
        saved_state = []
        for t in range(feature.shape[0]):
            # Slice with t:t + 1 (not [t]) so the seq_len dimension is kept: step output is (1, B, hidden_size).
            step_out, rnn_state = self.rnn(feature[t:t + 1], rnn_state)
            if keep_states and t + 1 in saved_state_timesteps:
                saved_state.append(rnn_state)
            step_outputs.append(step_out)
            # Join the per-sample h tensors along dim 1 -> (1, batch_size, head_hidden_size).
            step_hiddens.append(torch.cat([state['h'] for state in rnn_state], dim=1))

        core = torch.cat(step_outputs, 0)  # (T, B, head_hidden_size)
        if self.res_link:
            core = core + feature
        result = parallel_wrapper(self.head)(core)  # (T, B, action_shape)
        # next_state is the state after the last timestep fed to the RNN, including the hidden state (h)
        # and the cell state (c); layout: {list: B{dict: 2{Tensor:(1, 1, head_hidden_size}}}
        result['next_state'] = rnn_state
        # All hidden states h, of dim seq_len*batch_size*head_hidden_size.
        # This key is used in qtran, which requires all h_{t} to be retained during training.
        result['hidden_state'] = torch.cat(step_hiddens, dim=0)
        if keep_states:
            # The selected saved states, including the hidden state (h) and the cell state (c).
            # In r2d2, 'saved_hidden_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]' is set,
            # so saved_state records the states used to initialize the lstm (h c) for main_obs and target_obs.
            result['saved_state'] = saved_state
        return result


@MODEL_REGISTRY.register('gtrxldqn')
class GTrXLDQN(nn.Module):
    """
    Overview:
        The neural network structure and computation graph of Gated Transformer-XL DQN algorithm, which is the \
        enhanced version of DRQN, using Transformer-XL to improve long-term sequential modelling ability. The \
        GTrXL-DQN is composed of three parts: ``encoder``, ``core`` and ``head``. The ``encoder`` extracts features \
        from the observation (only used for image-like observations), the ``core`` (a ``GTrXL``) processes the \
        sequential data, and the ``head`` computes the Q value of each action dimension.
    Interfaces:
        ``__init__``, ``forward``, ``reset_memory``, ``get_memory`` .
    """

    def __init__(
        self,
        obs_shape: Union[int, SequenceType],
        action_shape: Union[int, SequenceType],
        head_layer_num: int = 1,
        att_head_dim: int = 16,
        hidden_size: int = 16,
        att_head_num: int = 2,
        att_mlp_num: int = 2,
        att_layer_num: int = 3,
        memory_len: int = 64,
        activation: Optional[nn.Module] = nn.ReLU(),
        head_norm_type: Optional[str] = None,
        dropout: float = 0.,
        gru_gating: bool = True,
        gru_bias: float = 2.,
        dueling: bool = True,
        encoder_hidden_size_list: SequenceType = [128, 128, 256],
        encoder_norm_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Initialize the GTrXLDQN model according to corresponding input arguments.

        .. tip::
            You can refer to GTrXl class in ``ding.torch_utils.network.gtrxl`` for more details about the input \
            arguments.

        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Used by Transformer. Observation's space. A length-3 \
                shape additionally creates a ``ConvEncoder`` that replaces the Transformer's embedding layer.
            - action_shape (:obj:`Union[int, SequenceType]`): Used by Head. Action's space.
            - head_layer_num (:obj:`int`): Used by Head. Number of layers.
            - att_head_dim (:obj:`int`): Used by Transformer.
            - hidden_size (:obj:`int`): Used by Transformer and Head.
            - att_head_num (:obj:`int`): Used by Transformer.
            - att_mlp_num (:obj:`int`): Used by Transformer.
            - att_layer_num (:obj:`int`): Used by Transformer.
            - memory_len (:obj:`int`): Used by Transformer.
            - activation (:obj:`Optional[nn.Module]`): Used by Transformer, Head and the conv Encoder.
            - head_norm_type (:obj:`Optional[str]`): Used by Head. The type of normalization to use, see \
                ``ding.torch_utils.fc_block`` for more details.
            - dropout (:obj:`float`): Used by Transformer, and by the dropout layer applied after the conv Encoder.
            - gru_gating (:obj:`bool`): Used by Transformer.
            - gru_bias (:obj:`float`): Used by Transformer.
            - dueling (:obj:`bool`): Used by Head. Make the head dueling.
            - encoder_hidden_size_list(:obj:`SequenceType`): Used by Encoder. The collection of ``hidden_size`` if \
                using a custom convolutional encoder. Its last element must equal ``hidden_size``.
            - encoder_norm_type (:obj:`Optional[str]`): Used by Encoder. The type of normalization to use, see \
                ``ding.torch_utils.fc_block`` for more details.
        Raises:
            - RuntimeError: If ``obs_shape`` is neither an int, a length-1 shape nor a length-3 shape.
        """
        super().__init__()
        self.core = GTrXL(
            input_dim=obs_shape,
            head_dim=att_head_dim,
            embedding_dim=hidden_size,
            head_num=att_head_num,
            mlp_num=att_mlp_num,
            layer_num=att_layer_num,
            memory_len=memory_len,
            activation=activation,
            dropout_ratio=dropout,
            gru_gating=gru_gating,
            gru_bias=gru_bias,
        )

        # Vector observations need no extra encoder (identity); only non-vector shapes are handled below.
        is_vector_obs = isinstance(obs_shape, int) or len(obs_shape) == 1
        if not is_vector_obs:
            if len(obs_shape) != 3:
                raise RuntimeError(
                    "not support obs_shape for pre-defined encoder: {}, please customize your own GTrXL".
                    format(obs_shape)
                )
            # Image-like observations: a Conv Encoder replaces the embedding layer of the Transformer,
            # so its output size has to equal the Transformer's embedding size.
            assert encoder_hidden_size_list[-1] == hidden_size
            self.obs_encoder = ConvEncoder(
                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=encoder_norm_type
            )
            self.dropout = nn.Dropout(dropout)
            self.core.use_embedding_layer = False

        # Head: Dueling or Discrete, wrapped in MultiHead when action_shape is not a single int.
        head_cls = DuelingHead if dueling else DiscreteHead
        if isinstance(action_shape, int):
            self.head = head_cls(
                hidden_size, action_shape, head_layer_num, activation=activation, norm_type=head_norm_type
            )
        else:
            self.head = MultiHead(
                head_cls,
                hidden_size,
                action_shape,
                layer_num=head_layer_num,
                activation=activation,
                norm_type=head_norm_type
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            Let input tensor go through GTrXl and the Head sequentially. A 5-dimensional input is first encoded \
            by the conv encoder (followed by dropout) before being fed to GTrXL.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor of shape (seq_len, bs, obs_shape).
        Returns:
            - out (:obj:`Dict`): run ``GTrXL`` with the head and return the result prediction dictionary.
        ReturnKeys:
            - logit (:obj:`torch.Tensor`): discrete Q-value output of each action dimension, shape is (B, action_space).
            - memory (:obj:`torch.Tensor`): memory tensor of size ``(bs x layer_num+1 x memory_len x embedding_dim)``.
            - transformer_out (:obj:`torch.Tensor`): the ``logit`` entry of the GTrXL output, i.e. the tensor that \
                is fed to the head.
        Examples:
            >>> # Init input's Keys:
            >>> obs_dim, seq_len, bs, action_dim = 128, 64, 32, 4
            >>> obs = torch.rand(seq_len, bs, obs_dim)
            >>> model = GTrXLDQN(obs_dim, action_dim)
            >>> outputs = model(obs)
            >>> assert isinstance(outputs, dict)
        """
        if x.dim() == 5:
            # 3d obs: cur_seq, bs, ch, h, w -> merge the first two dims so the conv encoder sees a plain batch.
            seq_len, bs = x.shape[0], x.shape[1]
            frames = x.reshape(seq_len * bs, *x.shape[-3:])
            encoded = self.dropout(self.obs_encoder(frames))
            x = encoded.reshape(seq_len, bs, -1)
        core_out = self.core(x)
        transformer_out = core_out['logit']  # output of gtrxl
        out = self.head(transformer_out)  # out['logit'] is the final output
        # layer_num+1 x memory_len x bs embedding_dim -> bs x layer_num+1 x memory_len x embedding_dim
        out['memory'] = core_out['memory'].permute(2, 0, 1, 3).contiguous()
        out['transformer_out'] = transformer_out
        return out

    def reset_memory(self, batch_size: Optional[int] = None, state: Optional[torch.Tensor] = None) -> None:
        """
        Overview:
            Clear or reset the memory of GTrXL by delegating to ``self.core.reset_memory``.
        Arguments:
            - batch_size (:obj:`Optional[int]`): The number of samples in a training batch.
            - state (:obj:`Optional[torch.Tensor]`): The input memory data, whose shape is \
                (layer_num, memory_len, bs, embedding_dim).
        """
        self.core.reset_memory(batch_size, state)

    def get_memory(self) -> Optional[torch.Tensor]:
        """
        Overview:
            Return the memory of GTrXL by delegating to ``self.core.get_memory``.
        Returns:
            - memory: (:obj:`Optional[torch.Tensor]`): output memory or None if memory has not been initialized, \
                whose shape is (layer_num, memory_len, bs, embedding_dim).
        """
        return self.core.get_memory()
